{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "612b6a05",
   "metadata": {},
   "source": [
    "# Distributed Training and Inference with XGBoost and LightGBM on Ray\n",
    "\n",
    "<a id=\"try-anyscale-quickstart-xgboost_example\" href=\"https://console.anyscale.com/register/ha?render_flow=ray&utm_source=ray_docs&utm_medium=docs&utm_campaign=xgboost_example\">\n",
    "    <img src=\"../../../_static/img/run-on-anyscale.svg\" alt=\"try-anyscale-quickstart\">\n",
    "</a>\n",
    "<br></br>\n",
    "\n",
    "(train-gbdt-guide)=\n",
    "\n",
    "> **Note**: The API shown in this notebook is now deprecated. Please refer to the updated API in [Getting Started with Distributed Training using XGBoost](../../getting-started-xgboost.rst) instead.\n",
    "\n",
    "\n",
    "In this tutorial, you'll discover how to scale out data preprocessing, training, and inference with XGBoost and LightGBM on Ray.\n",
    "\n",
    "To run this tutorial, we need to install the following dependencies:\n",
    "\n",
    "```bash\n",
    "pip install -qU \"ray[data,train]\" xgboost lightgbm\n",
    "```\n",
    "\n",
    "Then, we need some imports:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5a2250e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Tuple\n",
    "\n",
    "import pandas as pd\n",
    "import xgboost\n",
    "\n",
    "import ray\n",
    "from ray.data import Dataset, Preprocessor\n",
    "from ray.data.preprocessors import StandardScaler\n",
    "from ray.train import Checkpoint, CheckpointConfig, Result, RunConfig, ScalingConfig\n",
    "from ray.train.xgboost import XGBoostTrainer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ad88db8",
   "metadata": {},
   "source": [
    "Next we define a function to load our train, validation, and test datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "06b0f220",
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_data() -> Tuple[Dataset, Dataset, Dataset]:\n",
    "    \"\"\"Load and split the dataset into train, validation, and test sets.\"\"\"\n",
    "    dataset = ray.data.read_csv(\"s3://anonymous@air-example-data/breast_cancer.csv\")\n",
    "    train_dataset, valid_dataset = dataset.train_test_split(test_size=0.3)\n",
    "    test_dataset = valid_dataset.drop_columns([\"target\"])\n",
    "    return train_dataset, valid_dataset, test_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "56e67eb1",
   "metadata": {},
   "source": [
    "## How to preprocess data for training?\n",
    "\n",
    "Preprocessing is a crucial step in preparing your data for training, especially for tabular datasets.\n",
    "Ray Data offers built-in preprocessors that simplify common feature preprocessing tasks especially for tabular data.\n",
    "These can be seamlessly integrated with Ray Datasets, allowing you to preprocess your data in a fault-tolerant and distributed way before training. Here's how:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f12ca633",
   "metadata": {
    "tags": [
     "hide-output"
    ]
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-02-07 16:30:44,905\tINFO worker.py:1841 -- Started a local Ray instance.\n",
      "2025-02-07 16:30:45,596\tINFO streaming_executor.py:108 -- Starting execution of Dataset. Full logs are in /tmp/ray/session_2025-02-07_16-30-44_167214_9631/logs/ray-data\n",
      "2025-02-07 16:30:45,596\tINFO streaming_executor.py:109 -- Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadCSV] -> AggregateNumRows[AggregateNumRows]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cb0108523a6343808f4ce9e97a8c3f3f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Running 0: 0.00 row [00:00, ? row/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e13678df08ec4db48487b329c5c0ca43",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "- ReadCSV->SplitBlocks(24) 1: 0.00 row [00:00, ? row/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "be4bf2621cde4711af9f18ccd59e2580",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "- AggregateNumRows 2: 0.00 row [00:00, ? row/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-02-07 16:30:46,367\tINFO streaming_executor.py:108 -- Starting execution of Dataset. Full logs are in /tmp/ray/session_2025-02-07_16-30-44_167214_9631/logs/ray-data\n",
      "2025-02-07 16:30:46,367\tINFO streaming_executor.py:109 -- Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadCSV]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2f855e16cb0e4be1a754dfd7f38687ea",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Running 0: 0.00 row [00:00, ? row/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f65b593533424f75887168b33b6cf3fa",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "- ReadCSV->SplitBlocks(24) 1: 0.00 row [00:00, ? row/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-02-07 16:30:46,729\tINFO dataset.py:2704 -- Tip: Use `take_batch()` instead of `take() / show()` to return records in pandas or numpy batch format.\n",
      "2025-02-07 16:30:46,730\tINFO streaming_executor.py:108 -- Starting execution of Dataset. Full logs are in /tmp/ray/session_2025-02-07_16-30-44_167214_9631/logs/ray-data\n",
      "2025-02-07 16:30:46,730\tINFO streaming_executor.py:109 -- Execution plan of Dataset: InputDataBuffer[Input] -> AllToAllOperator[Aggregate] -> LimitOperator[limit=1]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "30c9df8a433641b8b70f2f8c58f8a455",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Running 0: 0.00 row [00:00, ? row/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "94191b314c144d2d90f5607a11880e83",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "- Aggregate 1: 0.00 row [00:00, ? row/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0d36381e0aad4cce85126b25b1021ccf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Sort Sample 2:   0%|          | 0.00/1.00 [00:00<?, ? row/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4310c898389f4371817967e79abcde9d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Shuffle Map 3:   0%|          | 0.00/1.00 [00:00<?, ? row/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "348c283a0b0e4322a5726e8355661f2e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Shuffle Reduce 4:   0%|          | 0.00/1.00 [00:00<?, ? row/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "33872f0781c4426db5574549ef8dcb10",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "- limit=1 5: 0.00 row [00:00, ? row/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Load and split the dataset\n",
    "train_dataset, valid_dataset, _test_dataset = prepare_data()\n",
    "\n",
    "# pick some dataset columns to scale\n",
    "columns_to_scale = [\"mean radius\", \"mean texture\"]\n",
    "\n",
    "# Initialize the preprocessor\n",
    "preprocessor = StandardScaler(columns=columns_to_scale)\n",
    "# train the preprocessor on the training set\n",
    "preprocessor.fit(train_dataset)\n",
    "# apply the preprocessor to the training and validation sets\n",
    "train_dataset = preprocessor.transform(train_dataset)\n",
    "valid_dataset = preprocessor.transform(valid_dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76b534fa",
   "metadata": {},
   "source": [
    "## Save and load XGBoost and LightGBM checkpoints\n",
    "\n",
    "Checkpointing is a powerful feature.\n",
    "It is particularly useful for long-running training sessions, as it enables you to resume training from the last checkpoint in case of interruptions.\n",
    "[`XGBoostTrainer`](https://docs.ray.io/en/latest/train/api/doc/ray.train.xgboost.XGBoostTrainer.html#ray.train.xgboost.XGBoostTrainer) and \n",
    "[`LightGBMTrainer`](https://docs.ray.io/en/latest/train/api/doc/ray.train.lightgbm.LightGBMTrainer.html#ray.train.lightgbm.LightGBMTrainer) both implement checkpointing out of the box. These checkpoints can be loaded into memory\n",
    "using static methods [`XGBoostTrainer.get_model`](https://docs.ray.io/en/latest/train/api/doc/ray.train.xgboost.XGBoostTrainer.get_model.html#ray.train.xgboost.XGBoostTrainer.get_model) and [`LightGBMTrainer.get_model`](https://docs.ray.io/en/latest/train/api/doc/ray.train.lightgbm.LightGBMTrainer.get_model.html#ray.train.lightgbm.LightGBMTrainer.get_model).\n",
    "\n",
    "The only required change is to configure [`CheckpointConfig`](https://docs.ray.io/en/latest/train/api/doc/ray.train.CheckpointConfig.html#ray.train.CheckpointConfig) to set the checkpointing frequency. For example, the following configuration\n",
    "saves a checkpoint on every boosting round and only keeps the latest checkpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9787bb14",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configure checkpointing to save progress during training\n",
    "run_config = RunConfig(\n",
    "    checkpoint_config=CheckpointConfig(\n",
    "        # Checkpoint every 10 iterations.\n",
    "        checkpoint_frequency=10,\n",
    "        # Only keep the latest checkpoint and delete the others.\n",
    "        num_to_keep=1,\n",
    "    )\n",
    "    ## If running in a multi-node cluster, this is where you\n",
    "    ## should configure the run's persistent storage that is accessible\n",
    "    ## across all worker nodes with `storage_path=\"s3://...\"`\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c9887b8e",
   "metadata": {},
   "source": [
    "## Basic training with tree-based models in Train\n",
    "\n",
    "Just as in the original [`xgboost.train()`](https://xgboost.readthedocs.io/en/stable/parameter.html) and [`lightgbm.train()`](https://lightgbm.readthedocs.io/en/latest/Parameters.html) functions, the\n",
    "training parameters are passed as the `params` dictionary.\n",
    "\n",
    "### XGBoost Example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a39765d7",
   "metadata": {
    "tags": [
     "hide-output"
    ]
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-02-07 16:32:31,783\tINFO tune.py:616 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "== Status ==\n",
      "Current time: 2025-02-07 16:32:31 (running for 00:00:00.11)\n",
      "Using FIFO scheduling algorithm.\n",
      "Logical resource usage: 3.0/12 CPUs, 0/0 GPUs\n",
      "Result logdir: /tmp/ray/session_2025-02-07_16-30-44_167214_9631/artifacts/2025-02-07_16-32-31/XGBoostTrainer_2025-02-07_16-32-31/driver_artifacts\n",
      "Number of trials: 1/1 (1 PENDING)\n",
      "\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8a4169c256994fcdb198d150ff7db46a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "(pid=10032) Running 0: 0.00 row [00:00, ? row/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b5a1cac8205e488cbb56a2631c229c99",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "(pid=10032) - StandardScaler 1: 0.00 row [00:00, ? row/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e6e352b856124bf2a05d6aa8641c370f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "(pid=10032) - split(2, equal=True) 2: 0.00 row [00:00, ? row/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "50a5c3459f4a4fb38fcfa488b0051f34",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "(pid=10033) Running 0: 0.00 row [00:00, ? row/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "eadbed22fe9c4568b3527cbe4bec5c49",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "(pid=10033) - StandardScaler 1: 0.00 row [00:00, ? row/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "94e362696bbd4860b0a03a6b0ef0f224",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "(pid=10033) - split(2, equal=True) 2: 0.00 row [00:00, ? row/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-02-07 16:32:34,045\tWARNING experiment_state.py:206 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds and may become a bottleneck. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this warning by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0). Set it to 0 to completely suppress this warning.\n",
      "2025-02-07 16:32:34,105\tWARNING experiment_state.py:206 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds and may become a bottleneck. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
      "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
      "You can suppress this warning by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0). Set it to 0 to completely suppress this warning.\n",
      "2025-02-07 16:32:35,137\tINFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/Users/rdecal/ray_results/XGBoostTrainer_2025-02-07_16-32-31' in 0.0110s.\n",
      "2025-02-07 16:32:35,140\tINFO tune.py:1041 -- Total run time: 3.36 seconds (3.34 seconds for the tuning loop).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "== Status ==\n",
      "Current time: 2025-02-07 16:32:35 (running for 00:00:03.35)\n",
      "Using FIFO scheduling algorithm.\n",
      "Logical resource usage: 3.0/12 CPUs, 0/0 GPUs\n",
      "Result logdir: /tmp/ray/session_2025-02-07_16-30-44_167214_9631/artifacts/2025-02-07_16-32-31/XGBoostTrainer_2025-02-07_16-32-31/driver_artifacts\n",
      "Number of trials: 1/1 (1 TERMINATED)\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Set up the XGBoost trainer with the specified configuration\n",
    "trainer = XGBoostTrainer(\n",
    "    # see \"How to scale out training?\" for more details\n",
    "    scaling_config=ScalingConfig(\n",
    "        # Number of workers to use for data parallelism.\n",
    "        num_workers=2,\n",
    "        # Whether to use GPU acceleration. Set to True to schedule GPU workers.\n",
    "        use_gpu=False,\n",
    "    ),\n",
    "    label_column=\"target\",\n",
    "    num_boost_round=20,\n",
    "    # XGBoost specific params (see the `xgboost.train` API reference)\n",
    "    params={\n",
    "        \"objective\": \"binary:logistic\",\n",
    "        # uncomment this and set `use_gpu=True` to use GPU for training\n",
    "        # \"tree_method\": \"gpu_hist\",\n",
    "        \"eval_metric\": [\"logloss\", \"error\"],\n",
    "    },\n",
    "    datasets={\"train\": train_dataset, \"valid\": valid_dataset},\n",
    "    # store the preprocessor in the checkpoint for inference later\n",
    "    metadata={\"preprocessor_pkl\": preprocessor.serialize()},\n",
    "    run_config=run_config,\n",
    ")\n",
    "result = trainer.fit()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7b18221b",
   "metadata": {},
   "source": [
    "We can now view the model's metrics:\n",
    "\n",
    "```python\n",
    "print(result.metrics)\n",
    "```\n",
    "\n",
    "This should output something like:\n",
    "\n",
    "```\n",
    "{'train-logloss': 0.00587594546605992, 'train-error': 0.0, 'valid-logloss': 0.06215000962556052, 'valid-error': 0.02941176470588235, 'time_this_iter_s': 0.0101318359375, 'should_checkpoint': True, 'done': True, 'training_iteration': 101, 'trial_id': '40fed_00000', 'date': '2023-07-06_18-33-25', 'timestamp': 1688693605, 'time_total_s': 4.901317834854126, 'pid': 40725, 'hostname': 'Balajis-MacBook-Pro-16', 'node_ip': '127.0.0.1', 'config': {}, 'time_since_restore': 4.901317834854126, 'iterations_since_restore': 101, 'experiment_tag': '0'}\n",
    "```\n",
    "\n",
    ":::{tip} Once you enable checkpointing, you can follow [this guide](https://docs.ray.io/en/latest/train/user-guides/fault-tolerance.html#train-fault-tolerance) to enable fault tolerance. :::"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0838a4e6",
   "metadata": {},
   "source": [
    "## LightGBM Example\n",
    "\n",
    "Modifying this example to use LightGBM instead of XGBoost is straightforward. You just have to change the trainer class and the model-specific parameters:\n",
    "\n",
    "```diff\n",
    "- from ray.train.xgboost import XGBoostTrainer\n",
    "+ from ray.train.lightgbm import LightGBMTrainer\n",
    "\n",
    "- trainer = XGBoostTrainer(\n",
    "+ trainer = LightGBMTrainer(\n",
    "\n",
    "- \"objective\": \"binary:logistic\",\n",
    "+ \"objective\": \"binary\",\n",
    "- \"eval_metric\": [\"logloss\", \"error\"],\n",
    "+ \"metric\": [\"binary_logloss\", \"binary_error\"],\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7816f41",
   "metadata": {},
   "source": [
    "## Running inference with a trained tree-based model\n",
    "\n",
    "Now that we have a trained model, we can use it to make predictions on new data.\n",
    "Let's define a utility function to perform streaming and distributed batch inference with our trained model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0e9c4293",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Predict:\n",
    "    def __init__(self, checkpoint: Checkpoint):\n",
    "        self.model = XGBoostTrainer.get_model(checkpoint)\n",
    "        # extract the preprocessor from the checkpoint metadata\n",
    "        self.preprocessor = Preprocessor.deserialize(\n",
    "            checkpoint.get_metadata()[\"preprocessor_pkl\"]\n",
    "        )\n",
    "\n",
    "    def __call__(self, batch: pd.DataFrame) -> pd.DataFrame:\n",
    "        preprocessed_batch = self.preprocessor.transform_batch(batch)\n",
    "        dmatrix = xgboost.DMatrix(preprocessed_batch)\n",
    "        return {\"predictions\": self.model.predict(dmatrix)}\n",
    "\n",
    "\n",
    "def predict_xgboost(result: Result):\n",
    "    _, _, test_dataset = prepare_data()\n",
    "\n",
    "    scores = test_dataset.map_batches(\n",
    "        Predict,\n",
    "        fn_constructor_args=[result.checkpoint],\n",
    "        concurrency=1,\n",
    "        batch_format=\"pandas\",\n",
    "    )\n",
    "\n",
    "    predicted_labels = scores.map_batches(\n",
    "        lambda df: (df > 0.5).astype(int), batch_format=\"pandas\"\n",
    "    )\n",
    "    print(\"PREDICTED LABELS\")\n",
    "    predicted_labels.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "21e21449",
   "metadata": {},
   "source": [
    "We can now get the predictions from the model on the test set:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "dc5222a0",
   "metadata": {
    "tags": [
     "hide-output"
    ]
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-02-07 16:30:52,878\tINFO streaming_executor.py:108 -- Starting execution of Dataset. Full logs are in /tmp/ray/session_2025-02-07_16-30-44_167214_9631/logs/ray-data\n",
      "2025-02-07 16:30:52,878\tINFO streaming_executor.py:109 -- Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadCSV] -> AggregateNumRows[AggregateNumRows]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d5caf740a9e646668c356738e2907e35",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Running 0: 0.00 row [00:00, ? row/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "df810cb770bd42ecb4cb94b99df90cef",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "- ReadCSV->SplitBlocks(24) 1: 0.00 row [00:00, ? row/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7d734282408a4c19aff6b6fba7e75edc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "- AggregateNumRows 2: 0.00 row [00:00, ? row/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-02-07 16:30:53,241\tINFO streaming_executor.py:108 -- Starting execution of Dataset. Full logs are in /tmp/ray/session_2025-02-07_16-30-44_167214_9631/logs/ray-data\n",
      "2025-02-07 16:30:53,241\tINFO streaming_executor.py:109 -- Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadCSV]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e756589fc8064083ad72a4ac185ceac4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Running 0: 0.00 row [00:00, ? row/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5b23c403389945b8a510d41d7e9b2f6c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "- ReadCSV->SplitBlocks(24) 1: 0.00 row [00:00, ? row/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-02-07 16:30:53,559\tINFO streaming_executor.py:108 -- Starting execution of Dataset. Full logs are in /tmp/ray/session_2025-02-07_16-30-44_167214_9631/logs/ray-data\n",
      "2025-02-07 16:30:53,559\tINFO streaming_executor.py:109 -- Execution plan of Dataset: InputDataBuffer[Input] -> ActorPoolMapOperator[MapBatches(drop_columns)->MapBatches(Predict)] -> TaskPoolMapOperator[MapBatches(<lambda>)] -> LimitOperator[limit=20]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PREDICTED LABELS\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6e10e44782a64d629e1325becee70729",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Running 0: 0.00 row [00:00, ? row/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c9aa1bee56bd4580b3809c406b041676",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "- MapBatches(drop_columns)->MapBatches(Predict) 1: 0.00 row [00:00, ? row/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "01736a4584d94484bca10c62f917eb9a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "- MapBatches(<lambda>) 2: 0.00 row [00:00, ? row/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3456fd5c616149c09bbaccac9ec980d8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "- limit=20 3: 0.00 row [00:00, ? row/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'predictions': 1}\n",
      "{'predictions': 1}\n",
      "{'predictions': 0}\n",
      "{'predictions': 1}\n",
      "{'predictions': 1}\n",
      "{'predictions': 1}\n",
      "{'predictions': 1}\n",
      "{'predictions': 1}\n",
      "{'predictions': 0}\n",
      "{'predictions': 1}\n",
      "{'predictions': 0}\n",
      "{'predictions': 1}\n",
      "{'predictions': 1}\n",
      "{'predictions': 1}\n",
      "{'predictions': 1}\n",
      "{'predictions': 0}\n",
      "{'predictions': 1}\n",
      "{'predictions': 1}\n",
      "{'predictions': 1}\n",
      "{'predictions': 0}\n"
     ]
    }
   ],
   "source": [
    "predict_xgboost(result)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "16c8ec6b",
   "metadata": {},
   "source": [
    "This should output something like:\n",
    "\n",
    "```\n",
    "PREDICTED LABELS\n",
    "{'predictions': 1}\n",
    "{'predictions': 1}\n",
    "{'predictions': 0}\n",
    "{'predictions': 1}\n",
    "{'predictions': 1}\n",
    "{'predictions': 1}\n",
    "{'predictions': 1}\n",
    "{'predictions': 1}\n",
    "{'predictions': 0}\n",
    "{'predictions': 1}\n",
    "{'predictions': 0}\n",
    "{'predictions': 1}\n",
    "{'predictions': 1}\n",
    "{'predictions': 1}\n",
    "{'predictions': 1}\n",
    "{'predictions': 0}\n",
    "{'predictions': 0}\n",
    "{'predictions': 1}\n",
    "{'predictions': 1}\n",
    "{'predictions': 0}\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f64200b",
   "metadata": {},
   "source": [
    "## How to scale out training?\n",
    "\n",
    "One of the key advantages of using Ray Train is its ability to effortlessly scale your training workloads.\n",
    "By adjusting the [`ScalingConfig`](https://docs.ray.io/en/latest/train/api/doc/ray.train.ScalingConfig.html#ray.train.ScalingConfig),\n",
    "you can optimize resource utilization and reduce training time, making it ideal for large-scale machine learning tasks.\n",
    "\n",
    ":::{note}\n",
    "Ray Train doesn’t modify or otherwise alter the working of the underlying XGBoost or LightGBM distributed training algorithms. Ray only provides orchestration, data ingest and fault tolerance. For more information on GBDT distributed training, refer to [XGBoost documentation](https://xgboost.readthedocs.io/en/stable/) and [LightGBM documentation](https://lightgbm.readthedocs.io/en/latest/).\n",
    ":::\n",
    "\n",
    "### Multi-node CPU Example\n",
    "\n",
    "Setup: 4 nodes with 8 CPUs each.\n",
    "\n",
    "Use-case: To utilize all resources in multi-node training.\n",
    "\n",
    "```python\n",
    "scaling_config = ScalingConfig(\n",
    "    num_workers=4,\n",
    "    resources_per_worker={\"CPU\": 8},\n",
    ")\n",
    "```\n",
    "\n",
    "### Single-node multi-GPU Example\n",
    "\n",
    "Setup: 1 node with 8 CPUs and 4 GPUs.\n",
    "\n",
    "Use-case: If you have a single node with multiple GPUs, you need to use\n",
    "distributed training to leverage all GPUs.\n",
    "\n",
    "```python\n",
    "scaling_config = ScalingConfig(\n",
    "    num_workers=4,\n",
    "    use_gpu=True,\n",
    ")\n",
    "```\n",
    "\n",
    "### Multi-node multi-GPU Example\n",
    "\n",
    "Setup: 4 nodes with 8 CPUs and 4 GPUs each.\n",
    "\n",
    "Use-case: If you have multiple nodes with multiple GPUs, you need to\n",
    "schedule one worker per GPU.\n",
    "\n",
    "```python\n",
    "scaling_config = ScalingConfig(\n",
    "    num_workers=16,\n",
    "    use_gpu=True,\n",
    ")\n",
    "```\n",
    "\n",
    "Note that you just have to adjust the number of workers. Ray handles everything else automatically.\n",
    "\n",
    "::: {warning}\n",
    "Specifying a *shared storage location* (such as cloud storage or NFS) is *optional* for single-node clusters, but it is **required for multi-node clusters**. Using a local path will [raise an error](https://docs.ray.io/en/latest/train/user-guides/persistent-storage.html#multinode-local-storage-warning) during checkpointing for multi-node clusters.\n",
    "\n",
    "```python\n",
    "trainer = XGBoostTrainer(\n",
    "    ..., run_config=ray.train.RunConfig(storage_path=\"s3://...\")\n",
    ")\n",
    "```\n",
    ":::"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31cded96",
   "metadata": {},
   "source": [
    "## How many remote actors should you use?\n",
    "\n",
    "This depends on your workload and your cluster setup. Generally there is no inherent benefit of running more than one remote actor per node for CPU-only training. This is because XGBoost can already leverage multiple CPUs with threading.\n",
    "\n",
    "However, in some cases, you should consider some starting more than one actor per node:\n",
    "\n",
    "For **multi GPU training**, each GPU should have a separate remote actor. Thus, if your machine has 24 CPUs and 4 GPUs, you want to start 4 remote actors with 6 CPUs and 1 GPU each\n",
    "\n",
    "In a **heterogeneous cluster**, you might want to find the [greatest common divisor](https://en.wikipedia.org/wiki/Greatest_common_divisor) for the number of CPUs. For example, for a cluster with three nodes of 4, 8, and 12 CPUs, respectively, you should set the number of actors to 6 and the CPUs per actor to 4.\n",
    "\n",
    "## How to use GPUs for training?\n",
    "\n",
    "Ray Train enables multi-GPU training for XGBoost and LightGBM. The core backends automatically leverage NCCL2 for cross-device communication. All you have to do is to start one actor per GPU and set GPU-compatible parameters. For example, XGBoost’s `tree_method` to `gpu_hist`. See XGBoost documentation for more details.\n",
    "\n",
    "For instance, if you have 2 machines with 4 GPUs each, you want to start 8 workers, and set `use_gpu=True`. There is usually no benefit in allocating less (for example, 0.5) or more than one GPU per actor.\n",
    "\n",
    "You should divide the CPUs evenly across actors per machine, so if your machines have 16 CPUs in addition to the 4 GPUs, each actor should have 4 CPUs to use.\n",
    "\n",
    "```python\n",
    "trainer = XGBoostTrainer(\n",
    "    scaling_config=ScalingConfig(\n",
    "        # Number of workers to use for data parallelism.\n",
    "        num_workers=2,\n",
    "        # Whether to use GPU acceleration.\n",
    "        use_gpu=True,\n",
    "    ),\n",
    "    params={\n",
    "        # XGBoost specific params\n",
    "        \"tree_method\": \"gpu_hist\",\n",
    "        \"eval_metric\": [\"logloss\", \"error\"],\n",
    "    },\n",
    "    ...\n",
    ")\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f04989d",
   "metadata": {},
   "source": [
    "## How to optimize XGBoost memory usage?\n",
    "\n",
    "XGBoost uses a compute-optimized data structure called `DMatrix` to store training data.\n",
    "However, converting a dataset to a `DMatrix` involves storing a complete copy of the data\n",
    "as well as intermediate conversions.\n",
    "On a 64-bit system the format is 64-bit floats. Depending on the system and original dataset dtype, \n",
    "this matrix can thus occupy more memory than the original dataset.\n",
    "\n",
    "The **peak memory usage** for CPU-based training is at least 3x the dataset size, assuming dtype `float32` on a 64-bit system, plus about **400,000 KiB** for other resources, like operating system requirements and storing of intermediate results.\n",
    "\n",
    "### Example\n",
    "\n",
    "- Machine type: AWS m5.xlarge (4 vCPUs, 16 GiB RAM)\n",
    "- Usable RAM: ~15,350,000 KiB\n",
    "- Dataset: 1,250,000 rows with 1024 features, dtype float32. Total size: 5,000,000 KiB\n",
    "- XGBoost DMatrix size: ~10,000,000 KiB\n",
    "\n",
    "This dataset fits exactly on this node for training.\n",
    "\n",
    "Note that the DMatrix size might be lower on a 32 bit system.\n",
    "\n",
    "### GPUs\n",
    "\n",
    "Generally, the same memory requirements exist for GPU-based training. Additionally, the GPU must have enough memory to hold the dataset.\n",
    "\n",
    "In the preceding example, the GPU must have at least 10,000,000 KiB (about 9.6 GiB) memory. However, empirical data shows that using a `DeviceQuantileDMatrix` seems to result in more peak GPU memory usage, possibly for intermediate storage when loading data (about 10%).\n",
    "\n",
    "### Best practices\n",
    "\n",
    "In order to reduce peak memory usage, consider the following suggestions:\n",
    "\n",
    "- Store data as `float32` or less. You often don’t need more precision is often, and keeping data in a smaller format helps reduce peak memory usage for initial data loading.\n",
    "- Pass the `dtype` when loading data from CSV. Otherwise, floating point values are loaded as `np.float64` per default, increasing peak memory usage by 33%."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ray",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
